# Expand the dataset using image augmentation techniques
import glob
import numpy as np
import os
import shutil
from utils import log_progress
import glob
import numpy as np
import os
import shutil
from utils import log_progress
import tensorflow as tf
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from keras.models import Sequential
from keras import optimizers
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array, array_to_img
# NOTE: required on GPU machines — without soft device placement and
# gpu memory growth, TensorFlow 2.x (run in v1-compat mode here) can
# grab all GPU memory up front and error out.
tf_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
tf_config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=tf_config)

# Load the training / validation image sets from disk.
IMG_DIM = (150, 150)  # target (height, width) every image is resized to

train_files = glob.glob('F:\\dogsvscats\\training_data\\*')
# Load each image and convert it to a (150, 150, 3) float array.
train_imgs = np.array([img_to_array(load_img(img, target_size=IMG_DIM))
                       for img in train_files])
# The class label is the filename prefix before the first '.',
# e.g. 'cat.123.jpg' -> 'cat'.  os.path.basename is robust to the
# directory depth, unlike the previous hard-coded fn.split('\\')[3].
train_labels = [os.path.basename(fn).split('.')[0].strip() for fn in train_files]

validation_files = glob.glob('F:\\dogsvscats\\validation_data\\*')
validation_imgs = np.array([img_to_array(load_img(img, target_size=IMG_DIM))
                            for img in validation_files])
validation_labels = [os.path.basename(fn).split('.')[0].strip()
                     for fn in validation_files]

# Encode the string labels ('cat'/'dog') as integers 0/1.
le = LabelEncoder()
le.fit(train_labels)
train_labels_enc = le.transform(train_labels)
validation_labels_enc = le.transform(validation_labels)

# Data generators: the training generator applies random augmentation,
# the validation generator only rescales pixel values to [0, 1].
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    zoom_range=0.3,
    rotation_range=50,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)

val_datagen = ImageDataGenerator(rescale=1. / 255)

# Preview: draw 5 augmented variants of a single training image.
img_id = 2595
cat_generator = train_datagen.flow(
    train_imgs[img_id:img_id + 1],
    train_labels[img_id:img_id + 1],
    batch_size=1,
)
cat = []
for _ in range(5):
    cat.append(next(cat_generator))
fig, ax = plt.subplots(1, 5, figsize=(16, 6))
print('Labels:', [batch[1][0] for batch in cat])
l = [ax[idx].imshow(cat[idx][0][0]) for idx in range(5)]
plt.show()

# Wrap the full training and validation sets in (augmenting) generators;
# training batches of 30, validation batches of 20.
train_generator = train_datagen.flow(
    train_imgs, train_labels_enc, batch_size=30)
val_generator = val_datagen.flow(
    validation_imgs, validation_labels_enc, batch_size=20)


input_shape = (150, 150, 3)
model = Sequential()

model.add(Conv2D(16, kernel_size=(3, 3), activation='relu',
                 input_shape=input_shape))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(128, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(128, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.3))
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.3))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy',
              optimizer=optimizers.RMSprop(lr=1e-4),
              metrics=['accuracy'])
model.summary()
# steps_per_epoch: number of batch iterations per epoch (here ~100 per epoch)
# history = model.fit(train_generator, steps_per_epoch=len(train_imgs)//30, epochs=100,
#                               validation_data=val_generator, validation_steps=len(validation_imgs)//20, verbose=1)
# The error arises because when model.fit is given both `epochs` and
# `steps_per_epoch`, each epoch consumes a fresh slice of the dataset, so the
# training and validation data would need to be repeated N (= epochs) times.
# Use dataset.repeat(epochs) to replicate the data `epochs` times.


# f, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
# t = f.suptitle('CNN with Image Augmentation Performance', fontsize=12)
# f.subplots_adjust(top=0.85, wspace=0.3)
#
# epoch_list = list(range(1,101))
# ax1.plot(epoch_list, history.history['acc'], label='Train Accuracy')
# ax1.plot(epoch_list, history.history['val_acc'], label='Validation Accuracy')
# ax1.set_xticks(np.arange(0, 101, 10))
# ax1.set_ylabel('Accuracy Value')
# ax1.set_xlabel('Epoch')
# ax1.set_title('Accuracy')
# l1 = ax1.legend(loc="best")
#
# ax2.plot(epoch_list, history.history['loss'], label='Train Loss')
# ax2.plot(epoch_list, history.history['val_loss'], label='Validation Loss')
# ax2.set_xticks(np.arange(0, 101, 10))
# ax2.set_ylabel('Loss Value')
# ax2.set_xlabel('Epoch')
# ax2.set_title('Loss')
# l2 = ax2.legend(loc="best")
# plt.show()

# https://www.e-learn.cn/topic/3683954